import torch
import os
from dataset import *
from torch.utils.data import DataLoader


def load_data():
    """Smoke-test the dataset pipeline.

    Builds train/eval ``Mdataset`` instances from hard-coded pickle paths,
    wraps them in ``DataLoader``s, and prints every batch of the training
    loader. Returns None; output goes to stdout only.

    NOTE(review): assumes ``Mdataset`` (from ``dataset``) takes a file path
    and yields ``(src, targ, src_len)`` tuples — confirm against dataset.py.
    """
    # Hard-coded data location; parameterize if this grows beyond a smoke test.
    tsv_store_dir = "/home/datanfs/macong_data/text_classification/trim_data"
    h5_train_pos = os.path.join(tsv_store_dir, "Chinese_conversation", "h5train.pkl")
    h5_valid_pos = os.path.join(tsv_store_dir, "Chinese_conversation", "h5valid.pkl")

    train_ = Mdataset(h5_train_pos)
    eval_ = Mdataset(h5_valid_pos)

    batch_size = 2
    # Keyword arguments used consistently for both loaders (the original
    # passed batch_size positionally to the eval loader).
    train_loader = DataLoader(train_, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_, batch_size=batch_size, shuffle=True)

    # The number of batches is loop-invariant: print it once up front
    # instead of re-printing it on every iteration.
    print(len(train_loader))

    # Renamed loop variable: `iter` shadowed the builtin.
    for batch_idx, (src, targ, src_len) in enumerate(train_loader):
        print(batch_idx)
        print(src)
        print(targ)
        print(src_len)


# Script entry point: run the data-loading smoke test when executed directly.
if __name__ == '__main__':
    load_data()